{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rNYRA6k8Qfyo"
      },
      "source": [
        "# Scenario Data Loading\n",
        "\n",
        "This tutorial demonstrates how to load scenario data from the Waymo Open Motion Dataset (WOMD) using the Waymax dataloader."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MtgRcYqmtMwD"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "import numpy as np\n",
        "import mediapy\n",
        "from tqdm import tqdm\n",
        "import dataclasses\n",
        "\n",
        "from waymax import config as _config\n",
        "from waymax import dataloader\n",
        "from waymax import datatypes\n",
        "from waymax import visualization"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0o2sAapxRMAT"
      },
      "source": [
        "We first create a dataset config, starting from the defaults provided in the `waymax.config` module (imported here as `_config`). In particular, `_config.WOD_1_1_0_TRAINING` is a predefined configuration that points to the training split of version 1.1.0 of the Waymo Open Motion Dataset.\n",
        "\n",
        "The data config contains a number of options that control how and where the dataset is loaded from. By default, the `WOD_1_1_0_TRAINING` config loads up to 128 objects (e.g. vehicles, pedestrians) per scenario. Here, we use `dataclasses.replace` to override `max_num_objects`, saving memory and compute by loading only the first 32 objects stored in each scenario.\n",
        "\n",
        "We use the `dataloader.simulator_state_generator` function to create an iterator over Open Motion Dataset scenarios. Calling `next` on the iterator retrieves the first scenario in the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dkJwTuSLr0gh"
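      },
      "outputs": [],
      "source": [
        "# Start from the default training config and load at most 32 objects per scenario.\n",
        "config = dataclasses.replace(_config.WOD_1_1_0_TRAINING, max_num_objects=32)\n",
        "# Create an iterator over scenarios and fetch the first one.\n",
        "data_iter = dataloader.simulator_state_generator(config=config)\n",
        "scenario = next(data_iter)"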
      ]
    },
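    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`dataclasses.replace` returns a modified copy rather than changing the object it is given, so the shared default config is left untouched. The cell below is a minimal sketch of that behavior. It uses `ToyDatasetConfig`, a hypothetical stand-in class whose fields and path are made up for illustration; it is not the real Waymax config class."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import dataclasses\n",
        "\n",
        "\n",
        "# Hypothetical stand-in for a dataset config; not the real Waymax class.\n",
        "@dataclasses.dataclass(frozen=True)\n",
        "class ToyDatasetConfig:\n",
        "  path: str\n",
        "  max_num_objects: int = 128\n",
        "\n",
        "\n",
        "toy_default = ToyDatasetConfig(path='/hypothetical/path')\n",
        "# replace() builds a new instance with the given fields overridden.\n",
        "toy_small = dataclasses.replace(toy_default, max_num_objects=32)\n",
        "\n",
        "# The original instance keeps its value; only the copy is changed.\n",
        "print(toy_default.max_num_objects, toy_small.max_num_objects)"
      ]
    },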
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q1xyeYpLR8J6"
      },
      "source": [
        "Next, we can plot the initial state of this scenario. We use a matplotlib-based visualization available in the `waymax.visualization` package."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OY3-OOArsFcU"
      },
      "outputs": [],
      "source": [
        "# Render the initial state, drawing agents from the logged trajectory.\n",
        "img = visualization.plot_simulator_state(scenario, use_log_traj=True)\n",
        "mediapy.show_image(img)"
      ]
    },
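    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Besides plotting, it can help to check what was actually loaded. The cell below is a minimal sketch. It assumes the scenario exposes its logged trajectory as `scenario.log_trajectory`, with an `xy` position array and a boolean `valid` mask; if your Waymax version names these attributes differently, adjust accordingly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Assumed attributes: log_trajectory.xy and log_trajectory.valid.\n",
        "log_traj = scenario.log_trajectory\n",
        "\n",
        "# Assumed layout: (objects, timesteps, 2) for xy and (objects, timesteps) for valid.\n",
        "print('xy shape:', log_traj.xy.shape)\n",
        "print('valid shape:', log_traj.valid.shape)\n",
        "\n",
        "# Count objects that have at least one valid logged timestep.\n",
        "print('objects with valid data:', int(log_traj.valid.any(axis=-1).sum()))"
      ]
    },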
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H0Z15epRSC23"
      },
      "source": [
        "The Waymo Open Motion Dataset consists of 9-second trajectory snippets sampled at 10 Hz. We can visualize the entire logged trajectory as a video by stepping the state through the log one timestep at a time and rendering each frame:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "06SjvXdRrV3N"
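      },
      "outputs": [],
      "source": [
        "imgs = []\n",
        "\n",
        "# Step through the logged trajectory one timestep at a time, rendering a frame per step.\n",
        "state = scenario\n",
        "for _ in range(scenario.remaining_timesteps):\n",
        "  state = datatypes.update_state_by_log(state, num_steps=1)\n",
        "  imgs.append(visualization.plot_simulator_state(state, use_log_traj=True))\n",
        "\n",
        "# Play back at 10 frames per second, matching the dataset's 10 Hz sampling rate.\n",
        "mediapy.show_video(imgs, fps=10)"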
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "14w5MbrMNLsOsLuD5kXy5-rrNO3ZgsHat",
          "timestamp": 1678404744504
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
